# -*- coding: utf-8 -*-

"""
分析来源日志，如果日志包含告警关键词，则将这些日志存储到 OBS，并发送通知到 SMN。
Analyze the source logs; if any log contains one of the alarm keywords, store those logs in OBS and send a notification to SMN.
 LTS 触发器触发函数时，event 结构如下：
 When an LTS trigger invokes this function, the event looks like:
    {
        "lts": {
        "data": "xxxxxxxxxxxxx"
        }
    }
  data 是经过 base64 编码的，使用前需要先解码。
  data is base64-encoded; decode it before use.
  
 如果通过函数 context 获取 AK/SK，请注意所配置的委托需要具备 SMN 和 OBS 的权限。
 If you obtain the AK/SK from the function context, note that the agency configured for this
 function must have both SMN and OBS permissions.
  
  
"""
import json
import time
import random
import base64
from obs import ObsClient
from huaweicloudsdkcore.auth.credentials import BasicCredentials
from huaweicloudsdksmn.v2.region.smn_region import SmnRegion
from huaweicloudsdksmn.v2 import SmnClient, PublishMessageRequestBody, PublishMessageRequest


# Object-key prefix ("folder") inside the OBS bucket that alarm logs are uploaded to.
LOGGER_PREFIX = "log"
# Keywords that mark a log record as alarm-worthy (matched as case-sensitive substrings).
ALARM_LOG_KEY = [
    "WARN",
    "WRN",
    "ERROR",
    "ERR",
]
# Subject line used for the SMN notification.
SMN_SUBJECT = "FunctionGraph Log Analysis Alarm"


def handler(event, context):
    """FunctionGraph entry point for an LTS trigger.

    Decodes the base64 LTS payload, keeps the logs that contain an alarm
    keyword, uploads them to OBS and publishes an SMN notification.

    Args:
        event: trigger event of the form ``{"lts": {"data": "<base64 JSON>"}}``.
        context: FunctionGraph runtime context; supplies the logger, the
            agency AK/SK, the project id and the user data entries
            ``obs_address``, ``obs_store_bucket`` and ``smn_urn``.

    Returns:
        ``"no need alarm"`` when no log matched, otherwise ``"alarm success"``.

    Raises:
        Exception: when a required user-data entry or the AK/SK is missing.
            Errors from SMN publishing propagate as well.
    """
    log = context.getLogger()
    # Verify configuration before doing any work.
    obs_address = context.getUserData("obs_address")
    obs_bucket = context.getUserData("obs_store_bucket")
    if not obs_address or not obs_bucket:
        raise Exception("Please configure obs environment variable")
    if not context.getAccessKey() or not context.getSecretKey():
        raise Exception(
            "Can not get accessKey or secretKey. Please check agency")
    if not context.getUserData('smn_urn'):
        raise Exception("Please configure SMN  environment variable")

    # Obtain the log payload from LTS: base64-encoded JSON.
    data = json.loads(base64.b64decode(event["lts"]["data"]))
    log.info(
        f"log group id [{data.get('log_group_id')}], topic id [{data.get('log_topic_id')}] ")
    logs = data["logs"]

    # Check whether any log needs an alarm.
    alarm_logs = analyze_logs(logs)
    log.info(f"{len(alarm_logs)} alarm log(s) found")
    if not alarm_logs:
        log.info("no need alarm")
        return "no need alarm"

    # Serialize the matched logs. analyze_logs returns each record as a JSON
    # string, so decode them and dump the list once. Stripping backslashes
    # from a double-encoded dump would break escaped quotes/newlines and
    # mangle non-ASCII "\uXXXX" escapes; this keeps valid, readable JSON.
    logs_str = json.dumps([json.loads(item) for item in alarm_logs],
                          ensure_ascii=False)

    # Upload to OBS (the agency needs OBS permission); always release the client.
    object_name = gen_log_name()
    obs_client = new_obs_client(context, obs_address)
    try:
        uploaded = upload_content_to_obs(obs_client, obs_bucket, logs_str, object_name)
    finally:
        obs_client.close()
    if uploaded:
        log_obs_path = f"check Full log at obs [{object_name}]"
    else:
        # upload_content_to_obs already printed the logs to the function log.
        log_obs_path = "check full log at FGS log"

    # Send the SMN alarm (the agency needs SMN permission).
    smn_client = new_smn_client(context)
    send_smn_msg(context, smn_client, logs_str, log_obs_path)
    # Optionally raise when the upload failed, so the invocation shows up as
    # failed in the FunctionGraph metrics:
    # if not uploaded: raise Exception(" Fail to upload obs")
    return 'alarm success'


# 分析日志 iter logs , if it contain key words store it in list then return
def analyze_logs(logs):
    alarm_logs = []
    if type(logs) != list:
        logs = json.loads(logs) 
    for log in logs:
        log_str = json.dumps(log)
        print("one lg :", log_str)
        for item in ALARM_LOG_KEY:
            if item in log_str:
                alarm_logs.append(log_str)
                break
    return alarm_logs


def gen_log_name(prefix=None):
    """Generate an OBS object key for an alarm-log upload.

    The key has the form ``<prefix>/log-<YYYYmmddHHMMSS>-<6 random digits>.log``,
    using local time from ``time.strftime``.

    Args:
        prefix: object-key prefix; defaults to ``LOGGER_PREFIX``.

    Returns:
        str: the generated object key.
    """
    if prefix is None:
        prefix = LOGGER_PREFIX
    timestamp = time.strftime("%Y%m%d%H%M%S")
    # randint is inclusive at both ends; 999999 keeps the suffix at exactly
    # six digits (1000000 would produce a seven-digit suffix).
    suffix = random.randint(100000, 999999)
    return f"{prefix}/log-{timestamp}-{suffix}.log"


def new_obs_client(context, obs_server):
    """Build an OBS client from the AK/SK held by the function context.

    Args:
        context: FunctionGraph runtime context (provides AK/SK).
        obs_server: OBS endpoint address.

    Returns:
        A new ``ObsClient`` bound to ``obs_server``.
    """
    access_key = context.getAccessKey()
    secret_key = context.getSecretKey()
    return ObsClient(access_key_id=access_key,
                     secret_access_key=secret_key,
                     server=obs_server)


def upload_content_to_obs(client: "ObsClient", bucket_name, content, obj_name):
    """Upload ``content`` to ``bucket_name``/``obj_name`` in OBS (best effort).

    On any failure the content is printed to stdout, so the alarm logs remain
    visible in the function's own log output.

    Args:
        client: OBS client exposing ``putContent``.
        bucket_name: target bucket.
        content: text to upload.
        obj_name: object key.

    Returns:
        bool: True when OBS answered with a status below 300, False on a
        status of 300 or above or on any exception while uploading.
    """
    try:
        resp = client.putContent(bucket_name, obj_name, content=content)
        # OBS reports success with status < 300; anything >= 300 is a failure.
        if resp.status >= 300:
            print('errorCode:', resp.errorCode)
            print('errorMessage:', resp.errorMessage)
            _print_source_log(content)
            return False
    except Exception:
        # Best effort: report and fall back to printing instead of crashing.
        import traceback
        print(traceback.format_exc())
        _print_source_log(content)
        return False
    return True


def _print_source_log(content):
    """Print the alarm logs so they are not lost when the OBS upload fails."""
    print("=========source log============")
    print(content)


def new_smn_client(context):
    """Build an SMN client for the region encoded in the ``smn_urn`` user data.

    The region is taken as the third ':'-separated field of the topic URN
    (presumably ``urn:smn:<region>:<project>:<topic>`` — confirm against SMN docs).
    Credentials are the context's AK/SK plus its project id.
    """
    topic_urn = context.getUserData("smn_urn")
    region_id = topic_urn.split(':')[2]
    creds = BasicCredentials(context.getAccessKey(),
                             context.getSecretKey(),
                             context.getProjectID())
    builder = SmnClient.new_builder()
    builder = builder.with_credentials(creds)
    builder = builder.with_region(SmnRegion.value_of(region_id))
    return builder.build()


def send_smn_msg(context, client, logs_str, log_obs_path):
    """Publish the alarm notification to the topic configured as ``smn_urn``.

    The message text is ``"<SMN_SUBJECT> | <log_obs_path>  : <logs_str>"``.
    Any exception raised by ``client.publish_message`` propagates to the caller.
    """
    print("start to send")
    publish_req = PublishMessageRequest()
    publish_req.topic_urn = context.getUserData("smn_urn")
    body_text = f"{SMN_SUBJECT} | {log_obs_path}  : {logs_str}"
    publish_req.body = PublishMessageRequestBody(subject=SMN_SUBJECT,
                                                 message=body_text)
    response = client.publish_message(publish_req)
    print("smn esp :", response)